'''
Author: SlytherinGe
LastEditTime: 2021-10-30 15:35:20
'''
import mmcv
import torch
import numpy as np
import matplotlib.pyplot as plt
# Directory holding the source JPEG images (SSDD dataset, VOC2012 layout).
IMG_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/'
# File stem shared by the image (<IMG_ID>.jpg) and its cached DFL tensor (<IMG_ID>.pkl).
IMG_ID = '000230'
# Initial (x, y) position, in feature-map pixels, whose vector is compared
# against every other position of the map.
TEST_POS = (100, 100)
def get_similarity_map(test_pos, img_id):
    """Compute the Euclidean distance from one DFL vector to all others.

    Loads the cached DFL tensor for ``img_id``, picks the vector at
    ``test_pos`` and measures its L2 distance to the vector at every
    feature-map position. Also loads the matching image and fits it to the
    feature-map size so both can be displayed on the same grid.

    Args:
        test_pos: ``(x, y)`` position in feature-map coordinates; ``x`` is the
            column (``0 <= x < w``) and ``y`` the row (``0 <= y < h``).
        img_id: file stem of both the cached ``.pkl`` DFL tensor and the
            ``.jpg`` image under ``IMG_ROOT``.

    Returns:
        ``(vis_vec, padded_img, test_vec)``: the ``(h, w, 1)`` numpy distance
        map, the RGB image rescaled and padded to ``(h, w)``, and the
        ``(vec_len,)`` reference vector.

    Raises:
        IndexError: if ``test_pos`` lies outside the ``(h, w)`` feature map.
    """
    # Element [0] of the cached object is a (vec_len, h, w) tensor.
    DFL_vec = mmcv.load('/home/gejunyao/ramdisk/DFL_Cache/' + img_id + '.pkl')[0]
    vec_len, h, w = DFL_vec.shape[:3]

    x, y = test_pos[0], test_pos[1]
    # Reject out-of-range positions explicitly: a negative index would
    # silently wrap around, and x >= w would silently select the next row.
    if not (0 <= x < w and 0 <= y < h):
        raise IndexError(
            'test_pos {} is outside the {}x{} (w x h) feature map'.format(
                test_pos, w, h))

    # (vec_len, h, w) -> (h * w, vec_len). The flattening is row-major over
    # (h, w), so position (x, y) lives at row y * w + x (not y * h + x).
    flat_vec = DFL_vec.permute(1, 2, 0).reshape(-1, vec_len)
    test_vec = flat_vec[y * w + x]
    # Distances are computed on the raw (un-normalised) vectors.
    dist = torch.norm(flat_vec - test_vec, dim=1)
    # detach/cpu so .numpy() also works for GPU or autograd tensors.
    vis_vec = dist.reshape(h, w, 1).detach().cpu().numpy()

    # mmcv.imread returns BGR; the result is shown with plt.imshow (RGB).
    img = mmcv.bgr2rgb(mmcv.imread(IMG_ROOT + img_id + '.jpg'))
    # Fit the image inside (h, w), then pad it to exactly (h, w).
    # Presumably this lets image pixels line up with feature-map cells.
    rescale_img = mmcv.imrescale(img, (h, w))
    padded_img = mmcv.impad(rescale_img, shape=(h, w))
    return vis_vec, padded_img, test_vec


def plt_refresh(distence_map, raw_img, test_vec):
    """Redraw the current figure for a new distance map.

    The left subplot shows the distance map. The right subplot shows the image
    overlaid with filled contours and labelled iso-distance lines.

    Args:
        distence_map: ``(h, w, 1)`` distance array from ``get_similarity_map``.
        raw_img: image of matching ``(h, w)`` size drawn under the contours.
        test_vec: reference vector (or ``None``) printed in the figure title.
    """
    dist = distence_map[:, :, 0]
    X, Y = np.meshgrid(np.arange(dist.shape[1]), np.arange(dist.shape[0]))
    # Fixed distance values at which labelled contour lines are drawn.
    levels = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100,
                       120, 140, 160, 180, 200, 220, 240, 260])

    fig = plt.gcf()
    # Clear the whole figure, not just the current axes (plt.cla()).
    # Otherwise every refresh stacks another image onto the left subplot.
    fig.clf()
    fig.suptitle('test vec val: {}'.format(test_vec))
    plt.subplot(1, 2, 1)
    plt.imshow(dist)
    plt.subplot(1, 2, 2)
    plt.imshow(raw_img)
    plt.contourf(X, Y, dist, levels=40, alpha=0.3)
    C = plt.contour(X, Y, dist, levels, cmap=plt.cm.summer)
    plt.clabel(C, inline=True, fontsize=12)
    # When invoked from an event callback the figure is not redrawn
    # automatically in non-interactive mode; request a redraw explicitly.
    fig.canvas.draw_idle()
    plt.show()

def on_press(event):
    """Recompute and redraw the distance map for a right-clicked position.

    Connected to ``button_press_event``. Only button 3 is handled. Clicks
    outside any axes report ``xdata``/``ydata`` as ``None``; these are ignored
    instead of raising.
    """
    if event.button != 3:
        return
    x, y = event.xdata, event.ydata
    if x is None or y is None:
        return
    # imshow centres pixel i at coordinate i (cell spans i-0.5 .. i+0.5).
    # Round to the nearest integer rather than truncating with int().
    pos = (int(np.floor(x + 0.5)), int(np.floor(y + 0.5)))
    vis_vec, result, test_vec = get_similarity_map(pos, IMG_ID)
    plt_refresh(vis_vec, result, test_vec)


if __name__ == '__main__':
    # Initial view for the default test position; keep the reference vector
    # so the title shows it, as it does after a click.
    vis_vec, result, test_vec = get_similarity_map(TEST_POS, IMG_ID)
    print(vis_vec.shape)
    # visualization
    fig = plt.figure(figsize=(20, 10))
    # Mouse presses are routed to on_press, which redraws for the clicked point.
    fig.canvas.mpl_connect("button_press_event", on_press)
    plt_refresh(vis_vec, result, test_vec)
